library("reticulate")
## Warning: package 'reticulate' was built under R version 4.0.5
library("knitr")
library("Hmisc")
## Loading required package: lattice
## Loading required package: survival
## Loading required package: Formula
## Loading required package: ggplot2
##
## Attaching package: 'Hmisc'
## The following objects are masked from 'package:base':
##
## format.pval, units
library("DescTools")
##
## Attaching package: 'DescTools'
## The following objects are masked from 'package:Hmisc':
##
## %nin%, Label, Mean, Quantile
library("stringr")
library("egg")
## Loading required package: gridExtra
library("tidyverse")
## ── Attaching packages ─────────────────────────────────────── tidyverse 1.3.1 ──
## ✓ tibble 3.1.6 ✓ purrr 0.3.4
## ✓ tidyr 1.1.4 ✓ dplyr 1.0.7
## ✓ readr 2.1.1 ✓ forcats 0.5.1
## ── Conflicts ────────────────────────────────────────── tidyverse_conflicts() ──
## x dplyr::combine() masks gridExtra::combine()
## x dplyr::filter() masks stats::filter()
## x dplyr::lag() masks stats::lag()
## x dplyr::src() masks Hmisc::src()
## x dplyr::summarize() masks Hmisc::summarize()
# set plotting theme
theme_set(theme_classic() +
            theme(text = element_text(size = 24)))
# knitr chunk display options
opts_chunk$set(comment = "",
               results = "hold",
               fig.show = "hold")
# suppress summarise() grouping message
# FIX: use FALSE, not F — T/F are ordinary variables and can be reassigned.
options(dplyr.summarise.inform = FALSE)
# Activate the conda environment where pandas is installed for reticulate.
use_condaenv("plinko")
pd = import("pandas")
# Load the pickled (xz-compressed) experiment dataset through pandas, since
# it was written with Python's pickle format.
df_data = pd$read_pickle("../../data/full_dataset_vision_corrected.xz")
# Filter dataset for analysis of specific experiment/participant set
# (trials 305 and 309 are excluded).
df_filtered_data = df_data %>%
filter(!trial %in% c(305, 309))
# One judgment row per participant x trial (the raw data has one row per
# recorded event, so duplicates are dropped).
df_data_judge = df_filtered_data %>%
  select(participant, trial, response) %>%
  unique()
# Response time per participant x trial: elapsed time between the first and
# last recorded event. log_rt maps rt == 0 to 0 to avoid log(0) = -Inf.
df_data_rt = df_filtered_data %>%
  group_by(participant, trial) %>%
  summarise(rt = t[length(t)] - t[1]) %>%
  mutate(log_rt = ifelse(rt == 0, 0, log(rt)))
# Per-trial proportion of training participants (1-15) selecting each hole,
# reshaped to long format: one row per trial x hole.
df_data_mean_judge_train = df_data_judge %>%
  filter(participant %in% 1:15) %>%
  group_by(trial) %>%
  summarise(hole1 = mean(response == 1),
            hole2 = mean(response == 2),
            hole3 = mean(response == 3)) %>%
  pivot_longer(hole1:hole3,
               names_to = "hole",
               values_to = "human_mean")
# Per-trial mean (log) response time for the training participants (1-15).
# FIX: removed a redundant mutate() that recomputed log_rt — df_data_rt
# already carries an identically-defined log_rt column.
df_data_mean_rt_train = df_data_rt %>%
  filter(participant %in% seq(1,15)) %>%
  group_by(trial) %>%
  summarise(mean_rt = mean(rt),
            mean_log_rt = mean(log_rt))
# Squared error between human per-trial hole-selection proportions and the
# bandit model's per-trial judgment proportions.
#
# df_human_means: long data frame with columns trial, hole, human_mean.
# df_model_raw: raw bandit runs with columns trial, run, judgment (0/1/2).
# Returns a single numeric squared-error score (lower = better fit).
compare_bandit_judgment = function(df_human_means, df_model_raw) {
  # Proportion of runs choosing each hole (judgments 0/1/2 -> hole1/2/3).
  model_means = df_model_raw %>%
    select(trial, run, judgment) %>%
    group_by(trial) %>%
    summarise(hole1 = mean(judgment == 0),
              hole2 = mean(judgment == 1),
              hole3 = mean(judgment == 2)) %>%
    pivot_longer(c(hole1, hole2, hole3),
                 names_to = "hole",
                 values_to = "model_mean")
  # Sum of squared human-model differences over all trial x hole cells.
  left_join(df_human_means, model_means, by = c("trial", "hole")) %>%
    summarise(sq_err = sum((human_mean - model_mean)^2)) %>%
    pull(sq_err)
}
# Squared error between per-trial mean log human response time and the bandit
# model's per-trial mean log "time" (number of simulations + looks).
#
# df_human_means: data frame with columns trial, mean_log_rt.
# df_model_raw: raw bandit runs with trial, run, num_sims, num_looks.
# Returns a single numeric squared-error score (lower = better fit).
compare_bandit_rt = function(df_human_means, df_model_raw) {
  df_model_means = df_model_raw %>%
    select(trial, run, num_sims, num_looks) %>%
    mutate(time_measure = num_sims + num_looks,
           # map a zero time measure to 0 rather than log(0) = -Inf
           log_time = ifelse(time_measure != 0, log(time_measure), time_measure)) %>%
    # mutate(time_measure = num_looks) %>%
    group_by(trial) %>%
    summarise(mean_time = mean(time_measure),
              # BUG FIX: was mean(log(time_measure)), which bypassed the
              # zero-guarded log_time above and could propagate -Inf.
              mean_log_time = mean(log_time))
  sq_err = left_join(df_human_means,
                     df_model_means,
                     by = c("trial")) %>%
    summarise(sq_err = sum((mean_log_rt - mean_log_time)^2)) %>%
    pull(sq_err)
  return(sq_err)
}
# Squared error between human per-trial hole-selection proportions and the
# fixed-sample model's per-hole scores.
#
# df_human_means: long data frame with columns trial, hole, human_mean.
# df_model_raw: model output with per-trial hole1/hole2/hole3 score columns.
# Returns a single numeric squared-error score (lower = better fit).
compare_fixed_judgment = function(df_human_means, df_model_raw) {
  # The fixed-sample model already reports one score per hole; just reshape
  # to long format so it lines up with the human data.
  model_long = df_model_raw %>%
    select(trial, hole1, hole2, hole3) %>%
    pivot_longer(c(hole1, hole2, hole3),
                 names_to = "hole",
                 values_to = "model_score")
  left_join(df_human_means, model_long, by = c("trial", "hole")) %>%
    summarise(sq_err = sum((human_mean - model_score)^2)) %>%
    pull(sq_err)
}
# Squared error between per-trial mean log human response time and the
# fixed-sample model's log time measure (simulations + looks).
#
# df_human_means: data frame with columns trial, mean_log_rt.
# df_model_raw: model output with trial, num_sims, num_looks.
# Returns a single numeric squared-error score (lower = better fit).
compare_fixed_rt = function(df_human_means, df_model_raw) {
  df_model = df_model_raw %>%
    select(trial, num_sims, num_looks) %>%
    mutate(time_measure = num_sims + num_looks,
           # ROBUSTNESS FIX: guard against log(0) = -Inf, matching the
           # handling in compare_bandit_rt (zero maps to itself, i.e. 0).
           log_time_measure = ifelse(time_measure != 0, log(time_measure), time_measure))
  sq_err = left_join(df_human_means,
                     df_model,
                     by = c("trial")) %>%
    summarise(sq_err = sum((mean_log_rt - log_time_measure)^2)) %>%
    pull(sq_err)
  return(sq_err)
}
# Collect grid-search results for the bandit model: one row per matching
# results file. Parameter values are parsed from the "_"-separated filename
# (the token following each parameter-name token).
# FIX: replaced append()-in-a-loop vector growth (quadratic copying) with
# lapply() building one tibble per file, combined once with bind_rows().
path = "../python/model/model_performance/grid_judgment_rt/"
list_filenames = list.files(path)
bandit_rows = lapply(list_filenames, function(file) {
  file_list = str_split(file, "_")[[1]]
  if (!("bandit" %in% file_list)) {
    return(NULL)  # non-bandit files are skipped; bind_rows() drops NULLs
  }
  # Mean earth-mover's distance for this parameter setting (same filename
  # under the grid_emd directory).
  df_dist = read.csv(paste("../python/model/model_performance/grid_emd/", file, sep = ""))
  df_model_raw = read.csv(paste(path, file, sep = ""))
  tibble(thresholds = as.numeric(file_list[match("threshold", file_list) + 1]),
         tradeoffs = as.numeric(file_list[match("tradeoff", file_list) + 1]),
         bws = as.numeric(file_list[match("bw", file_list) + 1]),
         sample_weights = as.numeric(file_list[match("weight", file_list) + 1]),
         judge_err = compare_bandit_judgment(df_data_mean_judge_train, df_model_raw),
         rt_err = compare_bandit_rt(df_data_mean_rt_train, df_model_raw),
         look_dist = mean(df_dist$distance))
})
df_bandit_performance = bind_rows(bandit_rows)
# Rank each parameter setting on each of the three criteria (lower = better
# fit), average the ranks into one combined score, and sort best-first.
df_bandit_performance = df_bandit_performance %>%
  mutate(rank_judge = rank(judge_err),
         rank_rt = rank(rt_err),
         rank_look_dist = rank(look_dist)) %>%
  mutate(combined_ranks = (rank_judge + rank_rt + rank_look_dist) / 3) %>%
  arrange(combined_ranks)
# Heatmaps of the average combined rank over each pair of bandit parameters,
# marginalizing over the remaining two. Lower ave_score = better overall fit.
# threshold x tradeoff (tradeoff on a log scale)
df_to_show = df_bandit_performance %>%
group_by(thresholds,
tradeoffs) %>%
summarise(ave_score = mean(combined_ranks))
ggplot(df_to_show, mapping = aes(x = thresholds,
y = tradeoffs,
fill = ave_score)) +
geom_tile() +
scale_y_continuous(trans="log10")
# threshold x bandwidth
df_to_show = df_bandit_performance %>%
group_by(thresholds,
bws) %>%
summarise(ave_score = mean(combined_ranks))
ggplot(df_to_show, mapping = aes(x = thresholds,
y = bws,
fill = ave_score)) +
geom_tile()
# threshold x sample weight
df_to_show = df_bandit_performance %>%
group_by(thresholds,
sample_weights) %>%
summarise(ave_score = mean(combined_ranks))
ggplot(df_to_show, mapping = aes(x = thresholds,
y = sample_weights,
fill = ave_score)) +
geom_tile()
# tradeoff x bandwidth (tradeoff on a log scale)
df_to_show = df_bandit_performance %>%
group_by(tradeoffs, bws) %>%
summarise(ave_score = mean(combined_ranks))
ggplot(df_to_show, mapping = aes(x = tradeoffs,
y = bws,
fill = ave_score)) +
geom_tile() +
scale_x_continuous(trans="log10")
# tradeoff x sample weight (tradeoff on a log scale)
df_to_show = df_bandit_performance %>%
group_by(tradeoffs, sample_weights) %>%
summarise(ave_score = mean(combined_ranks))
ggplot(df_to_show, mapping = aes(x = tradeoffs,
y = sample_weights,
fill = ave_score)) +
geom_tile() +
scale_x_continuous(trans = "log10")
# bandwidth x sample weight
df_to_show = df_bandit_performance %>%
group_by(bws, sample_weights) %>%
summarise(ave_score = mean(combined_ranks))
ggplot(df_to_show, mapping = aes(x = bws,
y = sample_weights,
fill = ave_score)) +
geom_tile()
# One-dimensional marginals: average combined rank as a function of each
# parameter separately, one facet per parameter.
df_to_show = df_bandit_performance %>%
pivot_longer(c(thresholds, tradeoffs, bws, sample_weights),
names_to = "param",
values_to = "param_val") %>%
group_by(param, param_val) %>%
summarise(ave_score = mean(combined_ranks))# %>%
# filter(!((param == "thresholds") & (param_val == 1.1)))
ggplot(df_to_show, mapping = aes(x = param_val, y = ave_score)) +
geom_point() +
geom_line() +
facet_wrap(~param, scales = "free")
# Collect grid-search results for the fixed-sample model: one row per
# matching results file, with parameters parsed from the filename.
# FIX: replaced append()-in-a-loop vector growth (quadratic copying) with
# lapply() building one tibble per file, combined once with bind_rows().
path = "../python/model/model_performance/grid_judgment_rt/"
list_filenames = list.files(path)
fixed_rows = lapply(list_filenames, function(file) {
  file_list = str_split(file, "_")[[1]]
  if (!("fixed" %in% file_list)) {
    return(NULL)  # non-fixed-sample files are skipped
  }
  # Mean earth-mover's distance for this parameter setting.
  df_dist = read.csv(paste("../python/model/model_performance/grid_emd/", file, sep = ""))
  df_model_raw = read.csv(paste(path, file, sep = ""))
  tibble(num_samples = as.numeric(file_list[match("samples", file_list) + 1]),
         bws = as.numeric(file_list[match("bw", file_list) + 1]),
         judge_err = compare_fixed_judgment(df_data_mean_judge_train, df_model_raw),
         rt_err = compare_fixed_rt(df_data_mean_rt_train, df_model_raw),
         look_dist = mean(df_dist$distance))
})
df_fixed_performance = bind_rows(fixed_rows)
# Rank each fixed-sample parameter setting on each criterion (lower = better
# fit), average the ranks into one combined score, and sort best-first.
df_fixed_performance = df_fixed_performance %>%
  mutate(rank_judge = rank(judge_err),
         rank_rt = rank(rt_err),
         rank_look_dist = rank(look_dist)) %>%
  mutate(combined_rank = (rank_judge + rank_rt + rank_look_dist) / 3) %>%
  arrange(combined_rank)
# Heatmap of the combined rank over the fixed-sample model's 2-D parameter
# grid, restricted to the regularly spaced portion of the grid.
df_to_show = df_fixed_performance %>%
filter(num_samples %in% seq(10, 150, 10),
bws %in% seq(2, 20, 2))
ggplot(data = df_to_show, mapping = aes(x = num_samples,
y = bws,
fill = combined_rank)) +
geom_tile()
# One-dimensional marginals for each fixed-sample parameter.
df_to_show = df_fixed_performance %>%
filter(num_samples %in% seq(10,150,10),
bws %in% seq(2,20,2)) %>%
pivot_longer(c(num_samples, bws),
names_to = "param",
values_to = "param_val") %>%
group_by(param, param_val) %>%
summarise(ave_score = mean(combined_rank))
ggplot(df_to_show, mapping = aes(x = param_val,
y = ave_score)) +
geom_line() +
geom_point() +
facet_wrap(~param, scales = "free")
# Distribution of raw participant response times.
ggplot(data = df_data_rt, mapping = aes(x = rt)) +
geom_histogram(fill = "grey", color = "black") +
ggtitle("Participant Response Times") +
xlab("Response Time (ms)") +
theme(plot.title = element_text(hjust=0.5))
`stat_bin()` using `bins = 30`. Pick better value with `binwidth`.
# Display the best-fitting bandit parameter setting (first row after
# arranging by combined rank above).
df_bandit_performance %>%
  slice_head(n = 1) %>%
  select(thresholds, tradeoffs, bws, sample_weights)
# A tibble: 1 × 4
thresholds tradeoffs bws sample_weights
<dbl> <dbl> <dbl> <dbl>
1 1 0.1 40 500
# Load the runs from the best-fitting bandit parameter setting found above;
# drop the pandas index column X.
df_model_judge_rt = read.csv("../python/model/model_performance/grid_judgment_rt/bandit_runs_30_threshold_1.0_tradeoff_0.1_sample_weight_500_bw_40.0_look_probs_0.5_0.8_0.02_0.5_noise_params_0.0_10.0_0.2_0.8_0.2_heuristic_prior_trial_0_150.csv") %>% select(-X)
# Per-trial distribution of model judgments, recoded from 0/1/2 to holes
# 1/2/3; complete() fills in zero proportions for holes never chosen.
df_model_mean_judge = df_model_judge_rt %>%
mutate(judgment = judgment + 1,
judgment=factor(judgment)) %>%
group_by(trial, judgment) %>%
summarise(model_mean = n()/(max(run)+1)) %>%
ungroup() %>%
complete(trial, judgment,
fill = list(model_mean=0))
# Human selection percentages per trial x hole with bootstrapped 95% CIs
# (smean.cl.boot comes from Hmisc).
df_data_mean_judge_full = df_data_judge %>%
mutate(hole1 = as.numeric(response == 1),
hole2 = as.numeric(response == 2),
hole3 = as.numeric(response == 3)) %>%
select(-response) %>%
pivot_longer(c(hole1, hole2, hole3),
names_to = "hole",
values_to = "response") %>%
mutate(response = response*100) %>%
group_by(trial, hole) %>%
do(data.frame(rbind(smean.cl.boot(.$response)))) %>%
rename(human_mean = Mean,
lower = Lower,
upper = Upper)
# Recode "hole1"/"hole2"/"hole3" to factor levels "1"/"2"/"3" so the human
# data joins against the model's judgment factor.
df_human_mean_judge = df_data_mean_judge_full %>%
mutate(hole = as.factor(str_sub(hole, -1, -1))) %>%
rename(judgment = hole)
df_to_show = left_join(df_model_mean_judge,
df_human_mean_judge,
by=c("trial", "judgment")) %>%
mutate(model = "Bandit")
# Correlation and RMSE between model proportions and human percentages
# (RMSE comes from DescTools).
model_cor = round(cor(df_to_show$model_mean, df_to_show$human_mean), digits=2)
model_rmse = round(RMSE(df_to_show$model_mean, df_to_show$human_mean), digits=2)
# Bandit judgment figure: model proportion (0-1) vs. human selection %
# (0-100); hence the identity line uses slope = 100.
ggplot(data = df_to_show, mapping = aes(x = model_mean,
y=human_mean)) +
geom_abline(slope = 100,
intercept = 0,
linetype="dotted") +
geom_linerange(mapping = aes(ymin = lower,
ymax = upper),
alpha = 0.2) +
geom_point(alpha=0.5) +
geom_smooth(method = "lm",
formula = y ~ x) +
facet_grid(~ model) +
xlab("Model Prediction") +
ylab("Participant Selection %") +
annotate("text",
label = paste("r: ", model_cor),
x=0.0,
y=100,
hjust=0) +
annotate("text",
label = paste("rmse: ", model_rmse),
x=0.0,
y=95,
hjust = 0) +
theme(plot.title = element_text(size=20, hjust=0.5),
axis.title = element_text(size=16),
axis.text = element_text(size=10))
ggsave("figures/bandit_judgment.pdf", height = 4, width = 5)
# Keep the bandit judgment data (prediction column renamed) for the combined
# model-comparison figures further below.
df_bandit_judge = df_to_show %>%
mutate(model = "Bandit") %>%
rename(prediction = model_mean)
# Model "response time" proxy per trial: simulations + looks, averaged over
# runs; log_time maps a zero time measure to 0 to avoid log(0) = -Inf.
df_model_mean_rt = df_model_judge_rt %>%
mutate(time_measure = num_sims + num_looks,
log_time = ifelse(time_measure != 0, log(time_measure), time_measure)) %>%
group_by(trial) %>%
summarise(mean_time = mean(time_measure),
mean_log_time = mean(log_time))
# Human mean (log) response time per trial across all participants.
# FIX: use the zero-guarded log_rt column (0 when rt == 0, defined with
# df_data_rt) instead of mean(log(rt)), which could introduce -Inf.
df_data_mean_rt = df_data_rt %>%
  group_by(trial) %>%
  summarise(mean_rt = mean(rt),
            mean_log_rt = mean(log_rt))
# Join model and human per-trial RT summaries on trial.
df_to_show = left_join(df_model_mean_rt,
df_data_mean_rt,
by = c("trial"))
# Fit statistics on the raw (non-log) scale.
model_cor = round(cor(df_to_show$mean_time, df_to_show$mean_rt), digits=2)
model_rmse = round(RMSE(df_to_show$mean_time, df_to_show$mean_rt), digits=2)
ggplot(data = df_to_show, mapping = aes(x = mean_time, y = mean_rt)) +
geom_point(alpha = 0.7,
shape=16) +
geom_smooth(method = "lm",
formula = y ~ x) +
# geom_label(mapping = aes(label = trial)) +
ggtitle("Bandit Response Time") +
xlab("Model Mean Looks Across Runs") +
ylab("Participant Mean log Response Time") +
annotate("text",
label = paste("r =", model_cor),
x=20,
y=2500,
hjust=0) +
annotate("text",
label = paste("rmse =", model_rmse),
x=20,
y=2200,
hjust=0) +
theme(plot.title = element_text(size=20, hjust=0.5),
axis.title = element_text(size=14),
axis.text = element_text(size=12))
ggsave("figures/bandit_rt.png", height = 4, width = 5)
# Redefine df_data_mean_rt with bootstrapped 95% CIs on the per-trial mean
# log RT (this version is also reused by the fixed-sample section below).
df_data_mean_rt = df_data_rt %>%
group_by(trial) %>%
do(data.frame(rbind(smean.cl.boot(.$log_rt)))) %>%
rename(mean_log_rt = Mean,
upper = Upper,
lower = Lower)
df_to_show = left_join(df_model_mean_rt,
df_data_mean_rt,
by = c("trial")) %>%
mutate(model = "Bandit")
# Fit statistics on the log scale.
model_cor = round(cor(df_to_show$mean_log_time, df_to_show$mean_log_rt), digits=2)
model_rmse = round(RMSE(df_to_show$mean_log_time, df_to_show$mean_log_rt), digits=2)
# Bandit log-RT figure: model mean log actions vs. human mean log RT, with
# bootstrapped CIs on the human means.
# FIX: removed the xvals/yvals vectors formerly defined here — they were
# never referenced anywhere (dead code).
ggplot(data = df_to_show, mapping = aes(x = mean_log_time, y = mean_log_rt)) +
  geom_linerange(mapping = aes(ymin = lower,
                               ymax = upper),
                 alpha = 0.15) +
  geom_point(alpha = 0.7,
             shape=16) +
  geom_smooth(method = "lm",
              formula = y ~ x) +
  facet_grid(~ model) +
  xlab("Model Mean log Actions") +
  ylab("Mean log Response Time") +
  annotate("text",
           label = paste("r: ", model_cor),
           size = 6,
           x=0.5,
           y=8.8,
           hjust=0) +
  annotate("text",
           label = paste("rmse: ", model_rmse),
           size =6,
           x=0.5,
           y=8.68,
           hjust=0) +
  theme(plot.title = element_text(size=20, hjust=0.5),
        axis.title = element_text(size=24),
        axis.text = element_text(size=18),
        plot.margin = margin(10, 0, 0, 10))
ggsave("figures/bandit_log_rt.pdf", height = 4, width = 5)
# Keep the bandit RT data (renamed columns) for the combined figures below.
df_bandit_rt = df_to_show %>%
rename(time_measure = mean_time,
log_time = mean_log_time)
# Distribution of the bandit model's look counts across all runs and trials.
ggplot(df_model_judge_rt, mapping = aes(x = num_looks)) +
geom_histogram(bins=30, fill = "grey", color = "black") +
ggtitle("Bandit Looks Histogram") +
xlab("Number of Looks") +
theme(plot.title = element_text(hjust=0.5))
# Display the best-fitting fixed-sample parameter setting (first row after
# arranging by combined rank above).
df_fixed_performance %>%
  slice_head(n = 1) %>%
  select(num_samples, bws)
# A tibble: 1 × 2
num_samples bws
<dbl> <dbl>
1 80 12
# Load the best-fitting fixed-sample model output; drop the pandas index X.
df_fixed_sample_judge_rt = read.csv("../python/model/model_performance/grid_judgment_rt/fixed_sample_num_samples_80_bw_12.0_look_probs_0.5_0.8_0.02_0.5_noise_params_0.0_10.0_0.2_0.8_0.2_trial_0_150.csv") %>% select(-X)
# Reshape the per-hole predictions to long format for joining with the
# human judgment means.
df_fixed_sample_long = df_fixed_sample_judge_rt %>%
select(trial, hole1, hole2, hole3) %>%
pivot_longer(c(hole1, hole2, hole3),
names_to = "hole",
values_to = "prediction")
df_to_show = df_fixed_sample_long %>%
left_join(df_data_mean_judge_full, by = c("trial", "hole")) %>%
mutate(model = "Fixed Sample")
# Fit statistics for the fixed-sample judgment predictions.
fixed_sample_cor = round(cor(df_to_show$prediction, df_to_show$human_mean), digits = 2)
fixed_sample_rmse = round(RMSE(df_to_show$prediction, df_to_show$human_mean), digits = 2)
# Fixed-sample judgment figure: model prediction (0-1) vs. human selection %
# (0-100); hence the identity line uses slope = 100.
ggplot(df_to_show, mapping = aes(x = prediction, y = human_mean)) +
geom_abline(slope = 100,
intercept = 0,
linetype = "dotted") +
geom_linerange(mapping = aes(ymin = lower,
ymax = upper),
alpha=0.2) +
geom_point(alpha=0.5,
shape=16) +
geom_smooth(method = "lm",
formula = y ~ x) +
annotate("text",
label = paste("r:", fixed_sample_cor),
x = 0.0,
y = 100,
hjust = 0) +
annotate("text",
label = paste("rmse:", fixed_sample_rmse),
x = 0.0,
y = 95,
hjust = 0) +
facet_grid(~ model) +
xlab("Model Prediction") +
ylab("Participant Mean Judgment") +
theme(plot.title = element_text(size=20,
hjust=0.5),
axis.title = element_text(size=16),
axis.text = element_text(size=10))
ggsave("figures/fixed_sample_judgments.pdf", height=4, width=5)
# Keep the fixed-sample judgment data (relabeled "Uniform Sampler") for the
# combined figures below.
df_fixed_judge = df_to_show %>%
mutate(judgment = as.factor(str_sub(hole, -1, -1)),
model = "Uniform Sampler") %>%
select(-hole)
# Fixed-sample RT proxy per trial (simulations + looks), joined with the
# bootstrapped human mean log RT.
df_to_show = df_fixed_sample_judge_rt %>%
select(trial, num_sims, num_looks) %>%
mutate(time_measure = num_sims + num_looks,
log_time = log(time_measure)) %>%
left_join(df_data_mean_rt, by = "trial") %>%
mutate(model = "Uniform Sampler")
fixed_sample_rt_cor = round(cor(df_to_show$log_time, df_to_show$mean_log_rt), digits = 2)
fixed_sample_rt_rmse = round(RMSE(df_to_show$log_time, df_to_show$mean_log_rt), digits = 2)
# Fixed-sample RT figure: model log time vs. human mean log RT with CIs.
ggplot(data = df_to_show, mapping = aes(x = log_time, y = mean_log_rt)) +
geom_linerange(mapping = aes(ymin = lower,
ymax = upper),
alpha = 0.15) +
geom_point(alpha=0.5,
shape=16) +
geom_smooth(method = "lm",
formula = y ~ x) +
facet_grid(~ model) +
xlab("Model Prediction") +
ylab("Participant Mean log \n Response Time") +
annotate("text",
label = paste("r:", fixed_sample_rt_cor),
x=6.3,
y=8.8,
size=6,
hjust=0) +
annotate("text",
label = paste("rmse:", fixed_sample_rt_rmse),
x=6.3,
y=8.68,
size=6,
hjust=0) +
theme(plot.title = element_text(size=20, hjust=0.5),
axis.title = element_text(size=24),
axis.text = element_text(size=18),
plot.margin = margin(10,0,0,10))
ggsave("figures/fixed_sample_rt.pdf", height=4, width = 5)
# Keep the fixed-sample RT data for the combined figures below.
df_fixed_rt = df_to_show %>%
select(-c(num_looks, num_sims))
# Distribution of the fixed-sample model's look counts.
ggplot(df_fixed_sample_judge_rt, mapping = aes(x = num_looks))+
geom_histogram(fill = "grey", color = "black") +
ggtitle("Fixed Sample Looks Histogram") +
xlab("Number of Looks") +
theme(plot.title = element_text(hjust=0.5))
`stat_bin()` using `bins = 30`. Pick better value with `binwidth`.
# Cogsci Figures
# Combined judgment figure: bandit and uniform-sampler predictions side by
# side, with per-model r and RMSE annotations computed once here.
df_to_show = rbind(df_bandit_judge,
df_fixed_judge)
df_sum_stat = df_to_show %>%
group_by(model) %>%
summarise(r = round(cor(prediction, human_mean), digits = 2),
rmse = round(RMSE(prediction, human_mean), digits = 2))
ggplot(df_to_show, mapping = aes(x = prediction, y = human_mean)) +
geom_abline(slope = 100,
intercept = 0,
linetype = "dotted") +
geom_linerange(mapping = aes(ymin = lower,
ymax = upper),
alpha = 0.2) +
geom_point(alpha = 0.5) +
geom_smooth(method = "lm") +
geom_text(data = df_sum_stat,
x = 0.0,
y = 100,
size = 6,
hjust = 0,
mapping = aes(label = paste("r: ", r, sep = ""))) +
geom_text(data = df_sum_stat,
x = 0.0,
y = 93,
size = 6,
hjust = 0,
mapping = aes(label = paste("rmse: ", rmse, sep = ""))) +
facet_wrap(~ model) +
xlab("Model Prediction") +
ylab("Participant % Selection") +
theme(plot.title = element_text(size=20, hjust=0.5),
axis.title = element_text(size=24),
axis.text = element_text(size=18),
panel.spacing = unit(2, "lines"))
`geom_smooth()` using formula 'y ~ x'
ggsave("figures/model_judgment.pdf",
width = 10,
height = 4)
`geom_smooth()` using formula 'y ~ x'
# Combined RT figure: bandit and uniform-sampler log-time predictions side
# by side, with per-model r and RMSE annotations.
df_to_show = rbind(df_bandit_rt,
df_fixed_rt)
df_sum_stat = df_to_show %>%
group_by(model) %>%
summarise(r = round(cor(log_time, mean_log_rt), digits = 2),
rmse = round(RMSE(log_time, mean_log_rt), digits = 2))
ggplot(df_to_show, mapping = aes(x = log_time,
y = mean_log_rt)) +
geom_linerange(mapping = aes(ymin = lower,
ymax = upper),
alpha = 0.3) +
geom_point(alpha = 0.7,
shape = 16) +
geom_smooth(method = "lm") +
geom_text(data = df_sum_stat,
x = 0.5,
y = 8.7,
hjust = 0,
size = 4,
mapping = aes(label = paste("r:", r))) +
geom_text(data = df_sum_stat,
x = 0.5,
y = 8.6,
hjust = 0,
size = 4,
mapping = aes(label = paste("rmse:", rmse))) +
facet_wrap(~ model,
scales = "free_x") +
xlab("Model Prediction") +
ylab("Mean log Response Time") +
theme(plot.title = element_text(size=20, hjust=0.5),
axis.title = element_text(size=14),
axis.text = element_text(size=12),
panel.spacing = unit(2, "lines"))
`geom_smooth()` using formula 'y ~ x'
# Earth-mover's-distance comparison between models and a baseline.
# NOTE(review): df_emd_bandit does not convert trial to a factor while the
# other two frames do — rbind() below coerces the column; confirm this mix
# is intentional.
df_emd_bandit = read.csv("../python/model/model_performance/emd/top_bandit.csv") %>%
select(trial, distance) %>%
mutate(model = "Bandit")
df_emd_fixed_sample = read.csv("../python/model/model_performance/emd/top_fixed_sample.csv") %>%
select(trial, distance) %>%
mutate(trial = factor(trial),
model = "Uniform Sampler")
df_emd_baseline = read.csv("../python/model/model_performance/emd/emd_baseline.csv") %>%
select(-X) %>%
mutate(trial = factor(trial),
model = "Baseline")
# Trials to color in the plot below (none currently selected).
to_highlight = c()
# Seed the horizontal jitter so the figure is reproducible.
set.seed(1)
# Recode models to x positions 1/2/3 and add per-point jitter so the lines
# connecting the same trial across models remain visible.
df_to_show = rbind(df_emd_bandit, df_emd_fixed_sample, df_emd_baseline) %>%
mutate(model = factor(model,
levels = c("Bandit", "Uniform Sampler", "Baseline"),
labels = c(1,2,3)),
model = as.numeric(as.character(model)),
highlight = trial %in% to_highlight,
model_jitter = model + runif(n = n(),
min = -0.15,
max = 0.15))
# ggplot(df_to_show, mapping = aes(x = model, y = distance)) +
# EMD comparison figure: one point per trial per model, with lines linking
# the same trial across models and bootstrapped means in red.
ggplot(df_to_show, mapping = aes(x = model,
y = distance,
color = highlight)) +
geom_line(mapping = aes(x = model_jitter, group = trial),
alpha = 0.05) +
geom_point(mapping = aes(x = model_jitter),
alpha = 0.5,
shape=16,
size=3) +
stat_summary(fun.data = "mean_cl_boot", color = "red", size=0.8) +
scale_x_continuous(breaks = c(1,2,3), labels = c("Bandit", "Uniform Sampler", "Baseline")) +
scale_color_manual(values = c("black", "magenta3")) +
ylab("Earth Mover's Distance") +
theme(legend.title = element_blank(),
legend.position = "none",
axis.title.y = element_text(size=24),
axis.title.x = element_blank(),
axis.text = element_text(size=20))
ggsave("figures/emd_comparison.pdf",
height = 5,
width = 10)
# Bootstrapped mean EMD (with 95% CI) per model, rounded to 2 decimals.
df_emd = rbind(df_emd_bandit,
df_emd_fixed_sample,
df_emd_baseline)
df_emd %>%
group_by(model) %>%
do(data.frame(rbind(round(smean.cl.boot(.$distance), 2))))
# A tibble: 3 × 4
# Groups: model [3]
model Mean Lower Upper
<chr> <dbl> <dbl> <dbl>
1 Bandit 51.2 48.8 53.7
2 Baseline 118. 116. 121.
3 Uniform Sampler 77.1 74.6 79.7
# Per-trial EMD difference between the two models, sorted so the trials on
# which the bandit most outperforms the uniform sampler come first.
df_emd = rbind(df_emd_bandit,
df_emd_fixed_sample)
df_emd %>%
mutate(model = ifelse(model == "Uniform Sampler", "uniform_sampler", "bandit")) %>%
pivot_wider(names_from = model,
values_from = distance) %>%
mutate(diff = uniform_sampler - bandit) %>%
arrange(desc(diff))
# A tibble: 150 × 4
trial bandit uniform_sampler diff
<chr> <dbl> <dbl> <dbl>
1 190 32.0 112. 79.5
2 67 37.1 116. 79.0
3 72 33.9 109. 74.9
4 12 49.5 113. 63.7
5 158 35.3 96.4 61.1
6 114 49.6 108. 58.8
7 46 25.5 84.1 58.5
8 223 36.5 95.0 58.5
9 254 34.3 91.8 57.6
10 20 35.5 90.3 54.8
# … with 140 more rows
# EMD distributions per model, stacked vertically for comparison.
ggplot(df_to_show, mapping = aes(x = distance, fill = model)) +
geom_histogram(color = "black") +
facet_wrap(~model, nrow = 3)
`stat_bin()` using `bins = 30`. Pick better value with `binwidth`.